{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9a6c57cf-4250-4b4f-9942-7c72ef4e279f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "device = \"cuda\"\n",
    "\n",
    "model_name_or_path = \"/data/nfs/llm/model/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"/data/nfs/llm/model/bloomz-560m\"\n",
    "\n",
    "peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\"/\", \"_\")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 3e-2\n",
    "num_epochs = 10\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a51cbbe1-0f0d-4c63-81b7-80c06dac9431",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset raft (/home/guodong.li/data/peft/data/raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84)\n",
      "100%|██████████| 2/2 [00:00<00:00, 269.92it/s]\n",
      "Loading cached processed dataset at /home/guodong.li/data/peft/data/raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-0e20fff6b1d898ca.arrow\n",
      "Loading cached processed dataset at /home/guodong.li/data/peft/data/raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-8d14a62b8a688c19.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Unlabeled', 'complaint', 'no complaint']\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
      "        num_rows: 50\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
      "        num_rows: 3399\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Tweet text': '@HMRCcustomers No this is my first job',\n",
       " 'ID': 0,\n",
       " 'Label': 2,\n",
       " 'text_label': 'no complaint'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "dataset = load_dataset(\"/home/guodong.li/data/peft/raft/raft.py\", dataset_name, cache_dir=\"/home/guodong.li/data/peft/data\")\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b0dc139b-ad24-4a27-8fad-a179180863e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "target_max_length: 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                          \r"
     ]
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(\"target_max_length:\", target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.pad_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9c84dcdb-3d98-41c2-b66f-efab96495378",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                           \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "          227985,   5484,    915,   2566,  74757,  64626,  12384,  44639,    613,\n",
       "           52282,   2670,  79920,   3344,   1002,    368,  17646,  14472,   8348,\n",
       "             664,    718,      4,  19036,     17,  31849,     17,   6312,     76,\n",
       "              44,  62470,     56,     91,     50,  14839,     21,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3, 227985,   5484,    915,    405, 187059,\n",
       "            2256,    664,   2550,  18833,  18607, 162467,      4,   1387,   6199,\n",
       "            3291,  23405,    613,   4657,  17082,    566,   3432,    368,  78851,\n",
       "            1185,  61273,  23181,   1553,  15596,    212, 116057,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3, 227985,   5484,\n",
       "             915,  39762,   2566,  22253,   6201,  75701,     15,    632,    718,\n",
       "            5840,  10006,   6201,  18881,    427,   3804,  19528,    267, 158974,\n",
       "            1320,    368,  10029,    632,  49666,     92,     34,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3, 227985,   5484,    915,   2566, 104565,   8695,   2089,   6140,\n",
       "          109676,  99579,   1369,    512,    368,   4570,     54,    632,    368,\n",
       "            1503, 241485, 132226,     15,    982,    727,   1152,  18100,    861,\n",
       "           32596,  77597, 168154,   1306, 132226,   4346,  87843,     17, 130462,\n",
       "             364,  32923,     89,     53,   8309,     20,     75,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3, 227985,   5484,    915,   2566,\n",
       "           14173,   2960,  29906,    387,  20706,  49337,   1369,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3, 227985,   5484,    915,   2566, 219553,  45736,\n",
       "           36876,   1713,     72,    707, 187205,  13002, 177324,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3, 227985,   5484,    915,   2566, 233938,  28518,  13716,\n",
       "             427,  28146,   1119,  17918,     17, 236706,    368, 214997,   7555,\n",
       "           48659,   5276,  21600,    343,     17,  51416,  22403,    318,   1531,\n",
       "            1306,   1130,  20934,    567, 101161, 184849,  87843,     17,   1594,\n",
       "           15231,   2052,  16642,     20,   7180,     80,     26,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "          227985,   5484,    915,   2566,     80,   2068,    479,   2566,     80,\n",
       "            1376,    878, 147587,   3904,    632,    368,   6084,  65673,  78851,\n",
       "           11736,  15527,  19082,  33151,    461,     17,  45575,  17887,    632,\n",
       "            5219,  14216,  68870,   5967,   1841,   4346,  87843,     17,   1594,\n",
       "           14512,     27,     71,   8184,     19,    290,  63748,  77658,    915,\n",
       "             210]]),\n",
       " 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (max_length - len(sample_input_ids)) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\"attention_mask\"][i]\n",
    "        \n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "test_dataset = dataset[\"test\"].map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "313387c6-82c5-42c2-9839-09256a48f3a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3, 227985,\n",
       "            5484,    915,   2566, 182441,     55,  35040,    435,  16796,  79920,\n",
       "             427,    661,   6355,     17,   5568,   2213,   3172,    960, 126355,\n",
       "            6216,   5559,  61273,  23181,  77658,    915,    210,   1936, 106863,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3, 227985,   5484,    915,   2566,\n",
       "          112229,   9107,  53312,   3262,   1306,   1152, 157816,   2084,  44326,\n",
       "           40006,    613,  27019,  16680,  77658,    915,    210,   1936, 106863,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3, 227985,   5484,    915,\n",
       "            2566,  24236,   2111, 178653,  11964,  20425,   3370,  10665,   1074,\n",
       "            2670,  20889,   5011,     17,  77658,    915,    210,  16449,   5952,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3, 227985,   5484,\n",
       "             915,   2566,     44,    256,  67875,  21033,  86274,  79707,   2632,\n",
       "            9999,    427,   2150,  54036,  98091,     34, 112164,  15971,  16154,\n",
       "            5382,    861,   7220,     17,  77658,    915,    210,   1936, 106863,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3, 227985,   5484,    915,   5161,  13500,    386,\n",
       "               4, 122906,    415,  66027,  42431,    613,  70016,    361,   2550,\n",
       "              51,  21351,    322,  17896,     15,   2550,     49,     45,   2550,\n",
       "           21450,    388,    655,   2550,     75,  14263,    916,     86,   1881,\n",
       "           36534,  53902,      4,   4346,  87843,     17, 130462,   8188,     23,\n",
       "            5949, 187347,     78, 130589,  77658,    915,    210,   1936, 106863,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3, 227985,   5484,    915,   2566,\n",
       "           31934,   5227,   6640,  16261,  87843,     17, 130462,   9600, 169668,\n",
       "              28,  13604, 112581,     40,  77658,    915,    210,   1936, 106863,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3, 227985,   5484,    915,   2566,  14173,   2960,  29906,    387,\n",
       "           73303,    473,   9283,  11257,    368,  64129,    361,  11571,    461,\n",
       "             490,   4283,  40067,   1620,   1130,   1186,  14881,   1002,     75,\n",
       "            1728,    368,  63049,     17,  77658,    915,    210,  16449,   5952,\n",
       "               3],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3, 227985,   5484,    915,   2566,  84152,   7643,   2111,\n",
       "             473, 243647,   3276,  18100,  16916,     17,  38641,   1306,   1369,\n",
       "          105961,   4936,  28077,     17,  77658,    915,    210,   1936, 106863,\n",
       "               3]]),\n",
       " 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
       " 'labels': tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   1936, 106863,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   1936, 106863,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,  16449,   5952,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   1936, 106863,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   1936, 106863,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   1936, 106863,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,  16449,   5952,\n",
       "               3],\n",
       "         [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,\n",
       "            -100,   -100,   -100,   -100,   -100,   -100,   -100,   1936, 106863,\n",
       "               3]])}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dab59b89-5180-4372-a990-b089fb8abf10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "425\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "          227985,   5484,    915,   2566,  74757,  64626,  12384,  44639,    613,\n",
       "           52282,   2670,  79920,   3344,   1002,    368,  17646,  14472,   8348,\n",
       "             664,    718,      4,  19036,     17,  31849,     17,   6312,     76,\n",
       "              44,  62470,     56,     91,     50,  14839,     21,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3, 227985,   5484,    915,    405, 187059,\n",
       "            2256,    664,   2550,  18833,  18607, 162467,      4,   1387,   6199,\n",
       "            3291,  23405,    613,   4657,  17082,    566,   3432,    368,  78851,\n",
       "            1185,  61273,  23181,   1553,  15596,    212, 116057,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3, 227985,   5484,\n",
       "             915,  39762,   2566,  22253,   6201,  75701,     15,    632,    718,\n",
       "            5840,  10006,   6201,  18881,    427,   3804,  19528,    267, 158974,\n",
       "            1320,    368,  10029,    632,  49666,     92,     34,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3, 227985,   5484,    915,   2566, 104565,   8695,   2089,   6140,\n",
       "          109676,  99579,   1369,    512,    368,   4570,     54,    632,    368,\n",
       "            1503, 241485, 132226,     15,    982,    727,   1152,  18100,    861,\n",
       "           32596,  77597, 168154,   1306, 132226,   4346,  87843,     17, 130462,\n",
       "             364,  32923,     89,     53,   8309,     20,     75,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3, 227985,   5484,    915,   2566,\n",
       "           14173,   2960,  29906,    387,  20706,  49337,   1369,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3, 227985,   5484,    915,   2566, 219553,  45736,\n",
       "           36876,   1713,     72,    707, 187205,  13002, 177324,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3, 227985,   5484,    915,   2566, 233938,  28518,  13716,\n",
       "             427,  28146,   1119,  17918,     17, 236706,    368, 214997,   7555,\n",
       "           48659,   5276,  21600,    343,     17,  51416,  22403,    318,   1531,\n",
       "            1306,   1130,  20934,    567, 101161, 184849,  87843,     17,   1594,\n",
       "           15231,   2052,  16642,     20,   7180,     80,     26,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "          227985,   5484,    915,   2566,     80,   2068,    479,   2566,     80,\n",
       "            1376,    878, 147587,   3904,    632,    368,   6084,  65673,  78851,\n",
       "           11736,  15527,  19082,  33151,    461,     17,  45575,  17887,    632,\n",
       "            5219,  14216,  68870,   5967,   1841,   4346,  87843,     17,   1594,\n",
       "           14512,     27,     71,   8184,     19,    290,  63748,  77658,    915,\n",
       "             210]]),\n",
       " 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(len(test_dataloader))\n",
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "15b07db0-c364-4243-9475-6bdfa1c52f64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1,474,560 || all params: 560,689,152 || trainable%: 0.26299064191632515\n"
     ]
    }
   ],
   "source": [
    "# creating model\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "abc071bf-9fce-4f2b-bc45-d71af1928015",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModelForCausalLM(\n",
       "  (base_model): BloomForCausalLM(\n",
       "    (transformer): BloomModel(\n",
       "      (word_embeddings): Embedding(250880, 1024)\n",
       "      (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "      (h): ModuleList(\n",
       "        (0): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (1): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (2): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (3): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (4): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (5): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (6): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (7): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (8): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (9): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (10): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (11): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (12): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (13): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (14): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (15): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (16): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (17): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (18): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (19): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (20): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (21): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (22): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "        (23): BloomBlock(\n",
       "          (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (self_attention): BloomAttention(\n",
       "            (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): BloomMLP(\n",
       "            (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (gelu_impl): BloomGelu()\n",
       "            (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (lm_head): Linear(in_features=1024, out_features=250880, bias=False)\n",
       "  )\n",
       "  (prompt_encoder): ModuleDict(\n",
       "    (default): PrefixEncoder(\n",
       "      (embedding): Embedding(30, 49152)\n",
       "    )\n",
       "  )\n",
       "  (word_embeddings): Embedding(250880, 1024)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "73b9c2c5-fcc7-4a18-b206-4a12e88c107b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'default': PrefixTuningConfig(peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, auto_mapping=None, base_model_name_or_path='/data/nfs/llm/model/bloomz-560m', revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=30, token_dim=1024, num_transformer_submodules=1, num_attention_heads=16, num_layers=24, encoder_hidden_size=1024, prefix_projection=False)}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.peft_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7e083a33-e294-4c25-a391-10981dc7ddb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model\n",
    "# optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2b361e7a-9e02-45d7-8da2-5fcb281ecd92",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:01<00:00,  5.82it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=tensor(1.0067e+10, device='cuda:0') train_epoch_loss=tensor(23.0325, device='cuda:0') eval_ppl=tensor(440.2399, device='cuda:0') eval_epoch_loss=tensor(6.0873, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.53it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=tensor(187.8570, device='cuda:0') train_epoch_loss=tensor(5.2357, device='cuda:0') eval_ppl=tensor(100.1909, device='cuda:0') eval_epoch_loss=tensor(4.6071, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.53it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=tensor(66.6660, device='cuda:0') train_epoch_loss=tensor(4.1997, device='cuda:0') eval_ppl=tensor(45.0488, device='cuda:0') eval_epoch_loss=tensor(3.8077, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.51it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=3: train_ppl=tensor(37.7751, device='cuda:0') train_epoch_loss=tensor(3.6316, device='cuda:0') eval_ppl=tensor(20.0135, device='cuda:0') eval_epoch_loss=tensor(2.9964, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.55it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=4: train_ppl=tensor(14.8737, device='cuda:0') train_epoch_loss=tensor(2.6996, device='cuda:0') eval_ppl=tensor(8.6689, device='cuda:0') eval_epoch_loss=tensor(2.1597, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.55it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=5: train_ppl=tensor(5.6210, device='cuda:0') train_epoch_loss=tensor(1.7265, device='cuda:0') eval_ppl=tensor(3.3080, device='cuda:0') eval_epoch_loss=tensor(1.1963, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.54it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=6: train_ppl=tensor(2.5512, device='cuda:0') train_epoch_loss=tensor(0.9366, device='cuda:0') eval_ppl=tensor(1.8660, device='cuda:0') eval_epoch_loss=tensor(0.6238, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.06it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=7: train_ppl=tensor(1.7045, device='cuda:0') train_epoch_loss=tensor(0.5333, device='cuda:0') eval_ppl=tensor(1.4844, device='cuda:0') eval_epoch_loss=tensor(0.3950, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.54it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=8: train_ppl=tensor(1.4116, device='cuda:0') train_epoch_loss=tensor(0.3447, device='cuda:0') eval_ppl=tensor(1.3491, device='cuda:0') eval_epoch_loss=tensor(0.2994, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 11.49it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=9: train_ppl=tensor(1.3204, device='cuda:0') train_epoch_loss=tensor(0.2780, device='cuda:0') eval_ppl=tensor(1.3255, device='cuda:0') eval_epoch_loss=tensor(0.2818, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Train for `num_epochs` epochs and evaluate after each one.\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # --- training pass ---\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {name: tensor.to(device) for name, tensor in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        # Sum a detached copy so the running total stays out of the autograd graph.\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()  # the schedule is stepped once per batch\n",
    "        optimizer.zero_grad()\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "\n",
    "    # --- evaluation pass (no gradients needed) ---\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    with torch.no_grad():\n",
    "        for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "            batch = {name: tensor.to(device) for name, tensor in batch.items()}\n",
    "            outputs = model(**batch)\n",
    "            loss = outputs.loss\n",
    "            eval_loss += loss.detach().float()\n",
    "            # Greedy (argmax) token ids for the batch, decoded back to text.\n",
    "            predicted_ids = torch.argmax(outputs.logits, -1).detach().cpu().numpy()\n",
    "            eval_preds.extend(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "\n",
    "    # Perplexity is reported as exp(mean per-batch loss).\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "826c92ba-0e19-4a43-ba89-df83034c2887",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hey @nytimes your link to cancel my subscription isn't working and nobody is answering the chat. Please don't play that kind of stupid game.\n",
      "{'input_ids': tensor([[227985,   5484,    915,  54078,   2566,   7782,  24502,   2632,   8989,\n",
      "            427,  36992,   2670, 140711,  21994,  10789,    530,  88399,    632,\n",
      "         183542,    368,  44799,     17,  29901,   5926,   7229,    861,  11596,\n",
      "            461,  78851,  14775,     17,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,  54078,   2566,   7782,  24502,   2632,   8989,\n",
      "            427,  36992,   2670, 140711,  21994,  10789,    530,  88399,    632,\n",
      "         183542,    368,  44799,     17,  29901,   5926,   7229,    861,  11596,\n",
      "            461,  78851,  14775,     17,  77658,    915,    210,   1936, 106863,\n",
      "              3]], device='cuda:0')\n",
      "[\"Tweet text : Hey @nytimes your link to cancel my subscription isn't working and nobody is answering the chat. Please don't play that kind of stupid game. Label : no complaint\"]\n"
     ]
    }
   ],
   "source": [
    "# Model evaluation: generate a label for one test tweet with the tuned model.\n",
    "model.eval()\n",
    "i = 16\n",
    "# Read the tweet through `text_column` so the dataset key and the prompt prefix\n",
    "# come from the same setting instead of a duplicated string literal.\n",
    "tweet = dataset[\"test\"][i][text_column]\n",
    "inputs = tokenizer(f\"{text_column} : {tweet} Label : \", return_tensors=\"pt\")\n",
    "print(tweet)\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    # NOTE(review): eos_token_id=3 is a hard-coded token id; presumably it is the id\n",
    "    # that terminates labels in the training data - confirm against the tokenizer.\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c9756276-4b79-4f2e-bf12-a092d8734651",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained adapter.\n",
    "# The target directory is the base-model path suffixed with the PEFT type and task type.\n",
    "adapter_suffix = f\"{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "peft_model_id = f\"{model_name_or_path}_{adapter_suffix}\"\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d602335e-bd15-4e87-98c9-43f6c6d3913f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "5.7M\t/data/nfs/llm/model/bloomz-560m_PREFIX_TUNING_CAUSAL_LM/adapter_model.bin\n",
      "--------------\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "/data/nfs/llm/model/bloomz-560m_PREFIX_TUNING_CAUSAL_LM\n",
      "├── [ 390]  adapter_config.json\n",
      "├── [5.6M]  adapter_model.bin\n",
      "└── [  93]  README.md\n",
      "\n",
      "0 directories, 3 files\n"
     ]
    }
   ],
   "source": [
    "# Inspect what was written: size of the adapter weights, then the directory tree.\n",
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h {ckpt}\n",
    "print(\"--------------\")\n",
    "!tree -h {peft_model_id}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "04885484-b34b-4d3e-8217-0da101dfdcda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved adapter on top of a freshly loaded base model.\n",
    "from peft import PeftConfig, PeftModel\n",
    "\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "# The saved adapter config records which base model it belongs to.\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(base_model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c3ef0558-be14-43c0-a28e-e0c623f5e2fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@greateranglia Ok thanks...\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210,   1936, 106863,      3]],\n",
      "       device='cuda:0')\n",
      "['Tweet text : @greateranglia Ok thanks... Label : no complaint']\n"
     ]
    }
   ],
   "source": [
    "# Generate a label for one test tweet with the reloaded adapter.\n",
    "model.to(device)\n",
    "model.eval()\n",
    "i = 4\n",
    "# Read the tweet through `text_column` so the dataset key and the prompt prefix\n",
    "# come from the same setting instead of a duplicated string literal.\n",
    "tweet = dataset[\"test\"][i][text_column]\n",
    "inputs = tokenizer(f\"{text_column} : {tweet} Label : \", return_tensors=\"pt\")\n",
    "print(tweet)\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    # NOTE(review): eos_token_id=3 is a hard-coded token id; presumably it is the id\n",
    "    # that terminates labels in the training data - confirm against the tokenizer.\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3a61b0a-2b61-4fcb-8e9f-2ee7096f700a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
